Fix support for optional inputs in model.fit #21548
base: master
Conversation
Summary of Changes
Hello @neo-alex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a critical issue where model.fit in Keras fails to properly handle optional inputs provided as None, unlike model.predict_on_batch and model.train_on_batch. My changes ensure that model.fit can correctly process models with optional None inputs across TensorFlow, JAX, and PyTorch backends, aligning its behavior with other training utilities.
Highlights
- Enabled model.fit with Optional None Inputs: The primary goal of this PR is to allow model.fit to successfully train models where certain inputs are optionally provided as None, which was previously causing failures.
- Enhanced tree.map_structure Handling of None: I've updated the tree.map_structure utility across Keras to include a none_is_leaf parameter. Setting this to False ensures that None values are traversed as part of the structure rather than being treated as terminal leaves, which is crucial for handling optional inputs (see the sketch after this list).
- TensorFlow Backend Adaptations: For the TensorFlow backend, I've introduced explicit conversions of tf.experimental.Optional instances to None within the train_step and ensured that None values from generators are correctly converted to tf.experimental.Optional.empty(None) when constructing tf.data.Dataset objects.
- Improved Data Adapter Robustness: Various data adapters (e.g., ArrayDataAdapter, GeneratorDataAdapter, TFDatasetAdapter, TorchDataLoaderAdapter, GrainDatasetAdapter) have been modified to correctly handle None values when calculating sample cardinalities, creating tensor specifications, or iterating through batches, preventing errors related to missing input shapes.
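To illustrate the `none_is_leaf` semantics described above, here is a minimal sketch (assuming this PR's branch, where `keras.tree.map_structure` accepts the new parameter; the sample data is illustrative):

```python
from keras import tree

data = {"input1": [1.0, 2.0], "input2": None}

# Default behavior (none_is_leaf=True): func also receives the None leaf,
# so it has to handle it explicitly.
tree.map_structure(lambda x: None if x is None else x * 2, data)
# -> {'input1': [2.0, 4.0], 'input2': None}

# With none_is_leaf=False, None stays part of the structure and is returned
# as-is; func only ever sees real values.
tree.map_structure(lambda x: x * 2, data, none_is_leaf=False)
# -> {'input1': [2.0, 4.0], 'input2': None}
```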
Code Review

This pull request adds support for optional inputs in `model.fit` by introducing a `none_is_leaf` parameter to `tree.map_structure`. This allows `None` values, which represent optional inputs, to be correctly handled across various data adapters and backends. The changes are logical and consistently applied. However, I've found a potential issue where the logic to handle TensorFlow's `Optional` type is missing from `test_step` and `predict_step`, which could cause problems during evaluation and prediction.
Codecov Report

❌ Patch coverage is … Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #21548      +/-   ##
==========================================
- Coverage   82.72%   82.71%   -0.01%
==========================================
  Files         567      567
  Lines       56264    56476     +212
  Branches     8797     8829      +32
==========================================
+ Hits        46544    46716     +172
- Misses       7562     7593      +31
- Partials     2158     2167       +9
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Thank you for the PR!

> I am unsure where would be the best place for it... model_test maybe?

Yes, this would be the place to test. Ideally, it would test `fit`, `predict` and `evaluate`. Also, ideally, this would be tested in `*_data_adapter_test.py` to cover all cases.

Taking a step back, is the goal to handle the case where, in the dataset passed to `fit`, "input2" is always `None`? Or sometimes `None`, sometimes not `None`? Right now it looks like it's only supporting the former (always `None`).
```diff
@@ -32,6 +32,8 @@ def get_tf_dataset(self):
 from keras.src.utils.module_utils import tensorflow as tf


+def convert_to_tf(x, spec):
+    if isinstance(spec, tf.OptionalSpec):
+        return x
```
Can't you just return `tf.experimental.Optional.empty(None)` here and remove lines 55-62? Or `tf.experimental.Optional.empty(None) if x is None else x`? Either way, lines 55-62 should move here.
Unfortunately no (this is what I tried first): an error is then raised by `tree.map_structure` on line 63, because `batch` and `self._output_signature` do not have the same structure (more specifically, `None` leaves in `batch` do not match the corresponding leaves in `self._output_signature`, which hold `tf.OptionalSpec` instead). This is why I had to convert the `None` leaves in `batch` first, in lines 55-62. Please let me know if you find a more elegant solution to this issue, though (I am also not a fan of having two `map_structure` calls in a row if it is avoidable).
Erratum: sorry, I got confused in my own tests (there is actually no issue using `tree.map_structure` with `None` leaves on one structure and `tf.OptionalSpec` on another, as long as `none_is_leaf=True`, which is the default). So you are right, and I simplified the logic according to your comment in this commit.
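For reference, the simplified `convert_to_tf` would presumably reduce to something like this sketch (the pass-through for regular values is an assumption; the actual code is in the linked commit):

```python
def convert_to_tf(x, spec):
    # Wrap a missing optional input in an empty Optional so that the batch
    # matches the tf.OptionalSpec in the dataset's output signature.
    if isinstance(spec, tf.OptionalSpec):
        return tf.experimental.Optional.empty(None) if x is None else x
    return x  # regular values are assumed to pass through unchanged
```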
keras/src/tree/tree_api.py (Outdated)

```diff
@@ -179,6 +179,7 @@ def map_structure(func, *structures):
     Args:
         func: A callable that accepts as many arguments as there are structures.
         *structures: Arbitrarily nested structures of the same layout.
+        none_is_leaf: If True, None is treated as a leaf.
```
Can you add more details here? The name `none_is_leaf` is pretty unintuitive actually. Basically, say something like: `none_is_leaf=True` causes `func` to be called on `None` leaves, and `none_is_leaf=False` means `None`s are not passed to `func` and are returned in the output directly.
Sure, I improved its docstring accordingly in this commit. By the way, I agree that the name `none_is_leaf` is not the most intuitive, but it is used consistently throughout the underlying optree library (e.g. here), so I kept the same one.
keras/src/tree/dmtree_impl.py (Outdated)

```python
if not all(s is None for s in args):
    raise ValueError(
        "Structure mismatch: some arguments are None, others "
        "are not."
    )
```
Issues while running `map_structure` can be hard to debug. Any bit of context can help. Can you add `args`?

```python
raise ValueError(
    "Structure mismatch: some arguments are None, others "
    f"are not: {args}."
)
```
OK, done in this commit.
```python
def _convert_optional_to_none(self, x):
    # Convert TF Optional implementations to None
    return tree.map_structure(
        lambda i: None if isinstance(i, tf.experimental.Optional) else i, x
    )
```
Don't you also need to do `i.get_value() if i.has_value() else None`? So that you support both the `None` and not-`None` cases?
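Applied to the helper above, the suggestion would read roughly like this (a sketch in eager form; note that `has_value()` returns a tensor, so graph-mode code may need a different formulation):

```python
def _convert_optional_to_none(self, x):
    def unwrap(i):
        # Non-empty Optionals yield their wrapped value, empty ones become
        # None; everything else passes through untouched.
        if isinstance(i, tf.experimental.Optional):
            return i.get_value() if i.has_value() else None
        return i

    return tree.map_structure(unwrap, x)
```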
Yes, you are probably right, I will double-check (see also my reply below to your "taking a step back" comment wrt. mixing `None` and not-`None` cases).
```diff
@@ -199,6 +203,8 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
     """
     from keras.src.utils.module_utils import tensorflow as tf

+    if keras_tensor is None:
```
Does this actually ever happen? My assumption was that this would need to handle non-`None` inputs that have `optional=True` on them (this might require some changes), and then create a `tf.OptionalSpec(<the actual TensorSpec for the KerasTensor per the code below>)`.
It does actually happen, even if the reason is not intuitive: your assumption makes a lot of sense (ideally, we would like optional inputs to be represented by `KerasTensor` with `optional=True`, like in the model); unfortunately, all the code in data_adapters is independent from the model, and the data spec is solely inferred from the first batches of received data (typically here)... which seems indeed a bit brittle and prone to some "hidden" constraints for the first batches of the dataset (e.g. see this error message). Since it is not possible to infer a proper `KerasTensor` just from a received `None` value, the trick I am using is to keep it as `None` (by using the newly introduced `none_is_leaf=False` inside `get_keras_tensor_spec`), which explains why the line of code you mention is actually needed.
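Concretely, the `None` branch presumably reduces to something like this sketch (the `tf.OptionalSpec(None)` return value is an assumption, based on the `tf.experimental.Optional.empty(None)` conversion discussed above):

```python
def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
    from keras.src.utils.module_utils import tensorflow as tf

    # A None leaf stands in for an optional input that was absent from the
    # sampled batches; its spec is that of an empty Optional.
    if keras_tensor is None:
        return tf.OptionalSpec(None)
    ...  # regular KerasTensor conversion continues below
```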
Thank you very much @hertschuh for your insightful review! To answer your "taking a step back" comment, for now the goal is at least to enable optional inputs in model fit/evaluate/predict when they are always or never `None`. I will continue to investigate and see if a solution for mixed values (sometimes `None`, sometimes not) is feasible.
You're right, the data spec and the model inputs are disconnected, which, as you point out, is the source of a number of shortcomings. It might not be possible to mix `None` and not-`None` values.
Here is an example of a model with 2 inputs, the second one being optional:
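A minimal sketch of such a model (the custom layer, names and shapes here are illustrative, not the exact code from the PR; `optional=True` on `keras.Input` is the mechanism being exercised):

```python
import numpy as np
import keras
from keras import layers


class AddIfPresent(layers.Layer):
    # The second argument may be None when the optional input is absent.
    def call(self, x1, x2=None):
        if x2 is None:
            return x1
        return x1 + x2


input1 = keras.Input(shape=(4,), name="input1")
input2 = keras.Input(shape=(4,), name="input2", optional=True)
outputs = layers.Dense(1)(AddIfPresent()(input1, input2))
model = keras.Model([input1, input2], outputs)
model.compile(optimizer="adam", loss="mse")
```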
With this definition, the model can be called in JAX/TF/Torch without issue, even when input2 is None:
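Continuing the sketch above:

```python
x1 = np.random.random((8, 4)).astype("float32")
out = model([x1, None])  # direct call works with input2=None
```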
It is even possible to train on a batch when input2 is None:
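For example (same sketch data):

```python
y = np.random.random((8, 1)).astype("float32")
model.train_on_batch([x1, None], y)  # works with input2=None
```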
However, doing the same with the model.fit API currently fails on all backends:
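Something along these lines:

```python
model.fit([x1, None], y, epochs=1)  # raises on all backends before this fix
```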
The purpose of this PR is to fix this issue (on JAX/TF/Torch), so that the last code block above becomes possible (btw, I could add 1 or 2 unit tests along those lines to demonstrate the fix, but I am unsure where the best place for it would be... model_test maybe?).